import os

import gym
import torch

from agents.interfaces import LearnerAlone
from tools.envs import make_dispatch
from tools.evaluators import generate_gwr_goals, image_env
from tools.utils import preprocess
import numpy as np


def data_generation(size=20):
    """Build one one-hot "image" per cell of a ``size`` x ``size`` grid.

    Row ``size * i + j`` of the result holds 255 at column ``size * i + j``
    and 0 everywhere else. Values are raw ``uint8``; they are normalized
    afterward by the preprocess function.

    Returns:
        A ``(size * size, size * size)`` ``torch.uint8`` tensor.
    """
    n_cells = size * size
    # Cell (i, j) lights up exactly its own flattened index, so the stacked
    # one-hot rows are a scaled identity matrix.
    one_hot = torch.eye(n_cells, dtype=torch.uint8) * 255
    return one_hot.view(-1, n_cells)


def data_generation2(size=20,avoid_walls=False,**kwargs):
    """Return the standardized (row, column) coordinates of every grid cell.

    Each coordinate ``c`` in ``range(size)`` is mapped to
    ``(c - size / 2) / ((size / 2) / 3)``. Rows of the result are ordered by
    flattened cell index ``size * i + j``.

    Args:
        size: Side length of the square grid.
        avoid_walls: When true, keep only the cells whose entry in the
            ``hard_env.txt`` layout equals 0.
        **kwargs: Accepted for call-site compatibility; unused.

    Returns:
        A ``(size * size, 2)`` float tensor, or the subset of its rows
        selected by the wall layout when ``avoid_walls`` is true.
    """
    half = size / 2
    mean = torch.tensor([half, half])
    std = torch.tensor([half / 3., half / 3.])

    # Standardize each axis once, then broadcast to the full grid instead of
    # filling the tensor cell by cell.
    axis = torch.arange(size, dtype=mean.dtype)
    row_coord = ((axis - mean[0]) / std[0]).view(-1, 1).expand(size, size)
    col_coord = ((axis - mean[1]) / std[1]).view(1, -1).expand(size, size)
    data = torch.stack((row_coord, col_coord), dim=-1).reshape(-1, 2)

    if not avoid_walls:
        return data

    # The layout file is only needed for filtering, so it is loaded lazily:
    # callers that do not filter no longer depend on the file being present.
    from envs.empty import open_file
    walls = torch.tensor(open_file("hard_env.txt"), dtype=torch.float)
    # NOTE(review): assumes the layout has exactly size * size entries - confirm.
    nowalls = walls.reshape(-1) == 0
    return data[nowalls]

def data_generation3(size=20,**kwargs):
    """Build one-hot cell "images" for the non-wall cells of ``hard_env.txt``.

    The layout is read with ``open_file("hard_env.txt")``. A cell whose
    layout value is 1 keeps an all-zero row; every other cell gets 255 at its
    own flattened index ``size * i + j``.

    Args:
        size: Side length of the grid that is iterated over.
        **kwargs: Accepted for call-site compatibility; unused.

    Returns:
        A pair ``(data, mask)``: ``data`` is a ``(size * size, size * size)``
        ``torch.uint8`` tensor and ``mask`` is the boolean tensor
        ``walls != 1``.
    """
    from envs.empty import open_file
    walls = torch.tensor(open_file("hard_env.txt"), dtype=torch.float)
    free_cells = walls != 1

    n_cells = size * size
    data = torch.zeros((size, size, n_cells), dtype=torch.uint8)
    for flat_index in range(n_cells):
        row, col = divmod(flat_index, size)
        if free_cells[row, col]:
            data[row, col, flat_index] = 255

    return data.view(-1, n_cells), free_cells


def evaluate_gridworld(env_type, env_name, sublearner, save_dir, goal_space,args, predictor=None, coordpolicy=None,
                         all_states=False, context=None, **kwargs):
    """Roll out the goal-conditioned policy on each test goal and save plots.

    For every goal returned by ``generate_gwr_goals`` an environment is built
    with ``make_dispatch``, the policy is stepped ``args.max_steps * 5`` times
    and the gathered statistics are rendered by ``goalpos_gridworld`` into
    ``<save_dir>/end``. Afterwards a reward map is written to
    ``<save_dir>/rewards`` via ``compute_reward``, the goals are dumped to
    ``<save_dir>/test_goals.txt`` and the embedding topology is plotted into
    ``<save_dir>/all_states`` via ``probs_all``.

    Args:
        env_type, env_name: Forwarded to ``probs_all``.
        sublearner: Object exposing ``mpolicy`` and ``rew_modules``.
        save_dir: Output directory; created (with parents) when missing.
        goal_space, context, **kwargs: Accepted but unused here.
        args: Run configuration; ``image``, ``env_type`` and ``max_steps``
            are read in this function.
        predictor: Forwarded to ``compute_reward`` and ``probs_all``.
        coordpolicy: Provides ``act`` and is passed to ``generate_gwr_goals``.
        all_states: Forwarded to ``probs_all``.
    """
    model_subpolicy = sublearner.mpolicy

    policy_save_dir = save_dir + "/end"
    all_states_save_dir = save_dir + "/all_states"
    predictor_save_dir = save_dir + "/rewards"
    # makedirs(exist_ok=True) avoids the check-then-create race and also
    # works when the parent of save_dir does not exist yet.
    for directory in (save_dir, policy_save_dir, all_states_save_dir, predictor_save_dir):
        os.makedirs(directory, exist_ok=True)

    size=30
    if args.image:
        data = preprocess(data_generation(size), args)
    else:
        data = data_generation2(size)

    all_goals, all_index_goals, all_index_encodes, all_index_goals_incorrect = generate_gwr_goals(coordpolicy,args)
    for i in range(all_goals.shape[0]):
        goals = all_index_encodes[i:i + 1]
        goals_decode = all_index_goals[i:i + 1]

        ##### Evaluation
        envs = make_dispatch(args.env_type,args,save_dir,video=False,random_pos=False,stats=True, render_stats=False,goal=goals[0,0].numpy())
        eval_subpolicy = model_subpolicy.clone(envs=envs)
        sublearner2 = LearnerAlone(eval_subpolicy, sublearner.rew_modules, [], [])
        sublearner2.eval()
        obs=envs.reset()
        for _ in range(args.max_steps*5):
            obs = obs["observation"] if isinstance(envs.observation_space,gym.spaces.Dict) else obs
            goal,_ = coordpolicy.act(obs,predefined_goal=all_goals[i:i+1])
            # NOTE(review): actions come from sublearner.mpolicy, not from the
            # cloned eval_subpolicy built above - confirm this is intended.
            act=sublearner.mpolicy.algo.act(obs,goals=goal,use_goals=True)
            obs, _, done, infos = envs.step(act)
            if done:
                obs = envs.reset()

        envs.close()

        ### Display the gathered statistics
        # NOTE(review): `infos` is only bound when the rollout loop ran at
        # least once (args.max_steps * 5 > 0).
        goalpos_gridworld(infos,goals,policy_save_dir,goals_decode,args)

    # NOTE(review): `goals` is the value left over from the last loop
    # iteration; it is unbound when there are no goals - confirm callers
    # always provide at least one.
    compute_reward(args, data,predictor,goals, all_goals, predictor_save_dir,size=size)

    # The context manager closes the file even if np.savetxt raises.
    with open(save_dir + "/test_goals.txt", "w") as goal_file:
        for row in all_goals:
            np.savetxt(goal_file, row)

    if args.image:
        data, mask = data_generation3(size)
    else:
        # Keep the coordinate data built above; only the wall mask is needed.
        _,mask = data_generation3(size)
    data = preprocess(data, args).float()
    with torch.no_grad():
        probs_all(env_type, env_name, data, predictor, mpolicy=model_subpolicy, mask=mask,
                  save_dir=all_states_save_dir, size=size, subproc=False, all_states=all_states, args=args)
    return


def probs_all(env_type,env_name,data,predictor,mpolicy=None, save_dir=None,size=20,all_states=False,args=None,mask=None,**kwargs):
    """Embed every grid state with ``predictor`` and plot the embedding.

    ``predictor.label_embed(data, act=True)`` is reshaped to
    ``(size, size, -1)``. A last dimension of 2 is drawn with
    ``plot_topology``; anything else is handed to ``plot_topology3d``
    together with ``mask``.

    Nothing is done when ``save_dir`` is None. ``env_type``, ``env_name``,
    ``mpolicy``, ``all_states``, ``args`` and ``**kwargs`` are accepted but
    not used by this function.
    """
    if save_dir is None:
        return

    embeddings = predictor.label_embed(data, act=True).cpu().view(size, size, -1)
    vals = embeddings.detach().cpu().numpy()

    draw_flat = vals.shape[2] == 2
    if draw_flat:
        plot_topology(vals, save_dir)
    else:
        plot_topology3d(vals, save_dir, mask=mask)


def goalpos_gridworld(infos,goals,save_dir,goals_decode=None,args=None,**kwargs):
    """Render the statistics gathered during a rollout for a single goal.

    An environment is built with ``make_dispatch`` using ``goals[0, 0]`` as
    the goal and ``infos["stats"]`` (when that key is present) as statistics.
    The result is saved with ``image_env`` when ``save_dir`` is given, and
    rendered in ``'human'`` mode otherwise.

    ``args`` defaults to None in the signature, but ``args.env_type`` is read
    unconditionally.
    """
    # Aggregate the statistics on goals, when the environment reported any
    statistics = np.array(infos["stats"], dtype=int) if 'stats' in infos else None

    envs = make_dispatch(
        args.env_type, args, save_dir,
        video=False,
        goal=goals[0, 0].cpu().numpy(),
        stats=statistics,
        render_stats=True,
        one_goal=True,
        **kwargs
    )
    if save_dir is None:
        envs.render(mode='human')
    else:
        image_env(envs, 0, goals_decode, save_dir)
    envs.close()

def compute_reward(args,data,predictor, goals, true_goals, save_dir=None,size=20,**kwargs):
    """Render the min-max normalized output of ``predictor.evaluate``.

    ``predictor.evaluate(data, true_goals)`` is viewed as a
    ``(size, size, 1)`` grid, rescaled to ``[0, 1]`` (left unscaled when all
    values are equal) and passed as ``stats`` to an environment created with
    ``make_dispatch``; ``image_env`` then saves the rendering. Everything
    runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        scores = predictor.evaluate(data, true_goals)
        grid = scores.view(size, size, 1).cpu().numpy().astype(np.float32)

        lowest = np.min(grid)
        highest = np.max(grid)
        # Guard against a zero range so a constant map does not divide by 0.
        spread = highest - lowest
        scale = spread if spread > 0 else 1
        statistics = (grid - lowest) / scale

        envs = make_dispatch(
            args.env_type, args, save_dir,
            video=False,
            goal=goals[0, 0].cpu().numpy(),
            stats=statistics,
            render_stats=True,
            one_goal=True,
            **kwargs
        )
        image_env(envs, 0, goals, save_dir, args=args)
        envs.close()

def plot_topology3d(data_view,file_path,mask=None):
    """Plot a grid of 3-D embeddings as a point cloud joined by grid edges.

    Args:
        data_view: Array indexed as ``[row, column, component]`` whose last
            dimension must be 3; any other size makes the function return
            without plotting.
        file_path: Prefix of the output image; the figure is saved to
            ``file_path + "_topology.png"`` and closed when it is truthy.
        mask: Optional 2-D boolean grid. Only cells whose entry is true are
            drawn, and an edge is drawn only when both of its endpoints are
            true.
    """
    if data_view.shape[2] != 3:
        return
    import matplotlib.pyplot as plt
    # Imported for its side effect: registers the "3d" projection on older
    # matplotlib releases.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    n_rows, n_cols = data_view.shape[0], data_view.shape[1]
    mask_1d = mask.reshape(-1) if mask is not None else None
    data=data_view.reshape(-1,data_view.shape[2])
    #https://www.idtools.com.au/3d-network-graphs-python-mplot3d-toolkit/
    # 3D network plot
    with plt.style.context(('ggplot')):
        fig = plt.figure(figsize=(10, 10))
        # `Axes3D(fig)` no longer attaches itself to the figure on recent
        # matplotlib, which left the saved image empty. add_axes with the
        # same full-figure rect works on old and new releases alike.
        ax = fig.add_axes((0.0, 0.0, 1.0, 1.0), projection="3d")
        ax.set_xlim3d(-3, 3)
        ax.set_ylim3d(-3, 3)
        ax.set_zlim3d(-3, 3)

        # One scatter call per visible node
        for n in range(data.shape[0]):
            if mask is None or mask_1d[n]:
                value=data[n,:]
                xi = value[0].item()
                yi = value[1].item()
                zi = value[2].item()

                # Scatter plot
                ax.scatter(xi, yi, zi,  s=5, c="blue", edgecolors='k', alpha=0.5)

        # Connect each node to its right and lower neighbour; the two
        # endpoints of a segment are the extrema of the plotted line.
        for i in range(n_rows):
            for j in range(n_cols):
                if j != n_cols - 1:
                    if mask is None or (mask[i,j] and mask[i,j+1]):
                        ax.plot(data_view[i, j:j + 2, 0], data_view[i, j:j + 2, 1],data_view[i, j:j + 2, 2], c='blue', alpha=0.5)
                if i != n_rows - 1:
                    if mask is None or (mask[i,j] and mask[i+1,j]):
                        ax.plot(data_view[i:i + 2, j, 0], data_view[i:i + 2, j, 1],data_view[i:i + 2, j, 2], c='blue', alpha=0.5)
    if file_path:
        plt.savefig(file_path+"_topology.png")
        plt.close('all')
    return

def plot_topology(data_view,file_path):
    """Plot a grid of 2-D embeddings as points joined by grid edges.

    The four corner cells and the centre cell are annotated with their
    ``"row-column"`` grid indices.

    Args:
        data_view: Array indexed as ``[row, column, component]`` whose last
            dimension must be 2; any other size makes the function return
            without plotting.
        file_path: Prefix of the output image; the figure is saved to
            ``file_path + "_topology.png"`` and closed when it is truthy.
    """
    if data_view.shape[2] != 2:
        return
    import matplotlib.pyplot as plt

    # The grid extent comes from the data itself. The previous code read an
    # undefined name `size` here and raised NameError.
    n_rows, n_cols = data_view.shape[0], data_view.shape[1]
    last_row, last_col = n_rows - 1, n_cols - 1
    mid_row, mid_col = n_rows // 2, n_cols // 2

    data=data_view.reshape(-1,data_view.shape[2])
    fig, ax = plt.subplots()
    fig.set_size_inches(10.0, 10.0)
    ax.set_xlim(np.min(data[:,0]),np.max(data[:,0]))
    ax.set_ylim(np.min(data[:,1]),np.max(data[:,1]))

    # Scatter plot
    ax.scatter(data[:,0],data[:,1],  s=5, c="blue", edgecolors='k', alpha=0.5)

    # Label the corners and the centre so the grid orientation can be read
    # off the embedding.
    labelled_cells = (
        (0, 0),
        (0, last_col),
        (last_row, 0),
        (last_row, last_col),
        (mid_row, mid_col),
    )
    for row, col in labelled_cells:
        ax.text(data_view[row, col, 0], data_view[row, col, 1], str(row)+"-"+str(col))

    # Connect each node to its right and lower neighbour
    for i in range(n_rows):
        for j in range(n_cols):
            if j != last_col:
                ax.plot(data_view[i,j:j+2,0], data_view[i,j:j+2,1], c='blue', alpha=0.5)
            if i != last_row:
                ax.plot(data_view[i:i+2, j, 0], data_view[i:i+2, j, 1], c='blue', alpha=0.5)
    if file_path:
        plt.savefig(file_path+"_topology.png")
        plt.close('all')
    return